package com.jstarcraft.ai.jsat.classifiers.linear.kernelized;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Random;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.classifiers.CategoricalResults;
import com.jstarcraft.ai.jsat.classifiers.ClassificationDataSet;
import com.jstarcraft.ai.jsat.classifiers.Classifier;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.classifiers.svm.SupportVectorLearner;
import com.jstarcraft.ai.jsat.distributions.Distribution;
import com.jstarcraft.ai.jsat.distributions.LogUniform;
import com.jstarcraft.ai.jsat.distributions.kernels.KernelTrick;
import com.jstarcraft.ai.jsat.exceptions.FailedToFitException;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.parameters.Parameterized;
import com.jstarcraft.ai.jsat.utils.ListUtils;
import com.jstarcraft.ai.jsat.utils.random.RandomUtil;

import it.unimi.dsi.fastutil.ints.IntArrayList;

/**
 * An implementation of Conservative Stochastic Kernel Logistic Regression. This
 * is an online algorithm that obtains sparse solutions by conservatively
 * rejecting updates based on a binomial distribution of the error on each
 * update. <br>
 * <br>
 * This algorithm works best on data sets with a very high number of samples
 * where a high accuracy is obtainable using a kernel model. It is often the
 * case that this model produces accurate results, but has a low confidence due
 * to the conservative updating. This can be counteracted by having a very large
 * number of features, but that often increases the size of the model. <br>
 * This batch version can also be used to more efficiently learn dense KLR
 * models using the stochastic method with the {@link CSKLR.UpdateMode#NC} mode
 * if model sparsity is not important. <br>
 * <br>
 * It is important to read the documentation and test some different values for
 * the {@link #setEta(double) learning rate} and {@link #setGamma(double) gamma}
 * variables. They behave differently compared to many algorithms. <br>
 * <br>
 * See paper: <br>
 * Zhang, L., Jin, R., Chen, C., Bu, J., &amp; He, X. (2012). <i>Efficient Online
 * Learning for Large-Scale Sparse Kernel Logistic Regression</i>. Twenty-Sixth
 * AAAI Conference on Artificial Intelligence (pp. 1219–1225). Retrieved from
 * <a href=
 * "http://www.aaai.org/ocs/index.php/AAAI/AAAI12/paper/viewPDFInterstitial/5003/5544">here</a>
 * 
 * @author Edward Raff
 */
public class CSKLRBatch extends SupportVectorLearner implements Parameterized, Classifier {
    private static final long serialVersionUID = -2305532659182911285L;

    /** The learning rate applied to every accepted update. */
    private double eta;
    /**
     * Running estimate of the norm of the current solution. It is accumulated
     * during training as the sum of {@code |alpha_i| * k(x_i, x_i)} over the
     * accepted updates and is compared against {@link #R} in the projection
     * step.
     */
    private double curNorm;
    /** The maximal norm allowed before the coefficients are rescaled. */
    private double R = 10;
    /** Reset to zero at the start of training; not otherwise read in this class. */
    private int T = 0;
    /** The update mode, which decides whether and how strongly to update. */
    private CSKLR.UpdateMode mode;
    /** The sparsity parameter handed to the {@link #mode update mode}. */
    protected double gamma = 2;
    /** The number of passes made over the training set. */
    private int epochs = 10;

    /**
     * Creates a new CSKLR Batch learning object
     * 
     * @param eta       the learning rate to use
     * @param kernel    the kernel to use
     * @param R         the maximal norm of the surface
     * @param mode      the mode to use
     * @param cacheMode the kernel caching mode to use
     * @throws IllegalArgumentException if {@code eta} or {@code R} is rejected
     *                                  by {@link #setEta(double)} or
     *                                  {@link #setR(double)}
     */
    public CSKLRBatch(double eta, KernelTrick kernel, double R, CSKLR.UpdateMode mode, CacheMode cacheMode) {
        super(kernel, cacheMode);
        setEta(eta);
        setR(R);
        setMode(mode);
    }

    /**
     * Copy constructor
     * 
     * @param toClone the object to copy
     */
    protected CSKLRBatch(CSKLRBatch toClone) {
        super(toClone);

        this.curNorm = toClone.curNorm;
        this.epochs = toClone.epochs;
        this.eta = toClone.eta;
        this.R = toClone.R;
        this.T = toClone.T;
        this.mode = toClone.mode;
        this.gamma = toClone.gamma;
    }

    @Override
    public CSKLRBatch clone() {
        return new CSKLRBatch(this);
    }

    /**
     * Sets the number of training epochs (passes) through the data set
     * 
     * @param epochs the number of passes through the data set
     */
    public void setEpochs(int epochs) {
        this.epochs = epochs;
    }

    /**
     * Returns the number of passes through the data set
     * 
     * @return the number of passes through the data set
     */
    public int getEpochs() {
        return epochs;
    }

    /**
     * Sets the learning rate to use for the algorithm. Unlike many other stochastic
     * algorithms, the learning rate for CSKLR should be large, often in the range
     * of (0.5, 1) - and can even be larger than 1 at times. If the learning rate is
     * too low, it may be difficult to get strong confidence results from the
     * algorithm.
     * 
     * @param eta the positive learning rate to use
     * @throws IllegalArgumentException if {@code eta} is negative, NaN, or
     *                                  infinite
     */
    public void setEta(double eta) {
        if (eta < 0 || Double.isNaN(eta) || Double.isInfinite(eta))
            throw new IllegalArgumentException("The learning rate should be in (0, Inf), not " + eta);
        this.eta = eta;
    }

    /**
     * Returns the learning rate to use
     * 
     * @return the learning rate to use
     */
    public double getEta() {
        return eta;
    }

    /**
     * Sets the maximal margin norm value for the algorithm. When the norm is
     * exceeded, the coefficients will be rescaled to fit in the norm. If the
     * maximal norm is too small (less than 5), it may be difficult to get strong
     * confidence results from the algorithm. <br>
     * A good range of values suggested by the original paper is 10<sup>x</sup>
     * &forall; x &isin; {0, 1, 2, 3, 4, 5}
     * 
     * @param R the maximal norm to allow
     * @throws IllegalArgumentException if {@code R} is negative, NaN, or infinite
     */
    public void setR(double R) {
        if (R < 0 || Double.isNaN(R) || Double.isInfinite(R))
            throw new IllegalArgumentException("The max norm should be in (0, Inf), not " + R);
        this.R = R;
    }

    /**
     * Returns the maximal norm of the algorithm
     * 
     * @return the maximal norm of the algorithm
     */
    public double getR() {
        return R;
    }

    /**
     * Sets what update mode should be used. The update mode controls the sparsity
     * of the mode, and the behavior of {@link #setGamma(double) }
     * 
     * @param mode the update mode to use
     */
    public void setMode(CSKLR.UpdateMode mode) {
        this.mode = mode;
    }

    /**
     * Returns the update mode in use
     * 
     * @return the update mode in use
     */
    public CSKLR.UpdateMode getMode() {
        return mode;
    }

    /**
     * Sets the gamma value to use. This value, depending on which
     * {@link CSKLR.UpdateMode} is used, controls the sparsity of the model.
     * 
     * @param gamma the gamma parameter, which is at least always positive
     * @throws IllegalArgumentException if {@code gamma} is negative, NaN, or
     *                                  infinite
     */
    public void setGamma(double gamma) {
        if (gamma < 0 || Double.isNaN(gamma) || Double.isInfinite(gamma))
            throw new IllegalArgumentException("Gamma must be in (0, Infity), not " + gamma);
        this.gamma = gamma;
    }

    /**
     * Returns the gamma sparsity parameter value
     * 
     * @return the gamma sparsity parameter value
     */
    public double getGamma() {
        return gamma;
    }

    /**
     * Guesses the distribution to use for the R parameter. The data set is not
     * inspected; the same log-uniform distribution over [1, 1e5] is always
     * returned.
     *
     * @param d the dataset to get the guess for
     * @return the guess for the R parameter
     * @see #setR(double)
     */
    public static Distribution guessR(DataSet d) {
        return new LogUniform(1, 1e5);
    }

    /**
     * Classifies the data point using the trained model.
     * 
     * @param data the data point to classify
     * @return a two-class result where class 0 receives the score computed for
     *         the label -1 and class 1 receives the remainder
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(2);

        // Training maps category 0 to the label -1, so score that label for class 0
        double p_0 = CSKLR.getScore(-1, getPreScore(data.getNumericalValues()));

        cr.setProb(0, p_0);
        cr.setProb(1, 1 - p_0);

        return cr;
    }

    /**
     * Trains the model. The {@code parallel} flag is ignored and the call is
     * forwarded to {@link #train(ClassificationDataSet)}.
     * 
     * @param dataSet  the binary classification data set to learn from
     * @param parallel ignored
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        train(dataSet);
    }

    /**
     * Trains the model by making {@link #getEpochs() epochs} passes over the
     * data set in a freshly shuffled order each pass. After training only the
     * points with a non-zero coefficient are kept.
     * 
     * @param dataSet the binary classification data set to learn from
     * @throws FailedToFitException if the data set does not have exactly two
     *                              classes
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        if (dataSet.getClassSize() != 2)
            throw new FailedToFitException("CSKLR supports only binary classification");

        // Every data point is a candidate support vector with a zero coefficient
        final int N = dataSet.size();
        vecs = new ArrayList<Vec>(N);
        alphas = new double[N];
        for (int i = 0; i < N; i++)
            vecs.add(dataSet.getDataPoint(i).getNumericalValues());

        curNorm = 0;
        T = 0;
        Random rand = RandomUtil.getRandom();

        IntArrayList sampleOrder = new IntArrayList(N);
        ListUtils.addRange(sampleOrder, 0, N, 1);

        setCacheMode(getCacheMode());// Initiates the cache

        for (int epoch = 0; epoch < epochs; epoch++) {
            // Shuffle with the same random source used for the acceptance test,
            // so the visiting order is governed by RandomUtil as well
            Collections.shuffle(sampleOrder, rand);

            for (int i : sampleOrder) {
                final double weight = dataSet.getWeight(i);
                // Map the category {0, 1} onto the label {-1, +1}
                final double y_t = dataSet.getDataPointCategory(i) * 2 - 1;
                final Vec x_t = vecs.get(i);
                final double pre = getPreScore(x_t);
                final double score = CSKLR.getScore(y_t, pre);

                // All modes except NC accept the update only with probability pt
                if (mode != CSKLR.UpdateMode.NC) {
                    double pt = mode.pt(y_t, score, pre, eta, gamma);
                    if (rand.nextDouble() > pt)
                        continue;
                }

                double alpha_i = -eta * y_t * mode.grad(y_t, score, pre, gamma) * weight;

                alphas[i] += alpha_i;

                curNorm += Math.abs(alpha_i) * kEval(i, i);

                // projection step
                if (curNorm > R) {
                    double coef = R / curNorm;
                    for (int j = 0; j < alphas.length; j++)
                        alphas[j] *= coef;
                    // Every coefficient was scaled by coef, so the tracked norm
                    // scales to curNorm * coef, which is exactly R
                    curNorm = R;
                }
            }

        }

        // Compact the support vectors to the front of vecs / alphas. The test
        // is written as two comparisons so that a NaN coefficient is dropped too.
        int supportVectorCount = 0;
        for (int i = 0; i < N; i++)
            if (alphas[i] > 0 || alphas[i] < 0)// Its a support vector
            {
                ListUtils.swap(vecs, supportVectorCount, i);
                alphas[supportVectorCount++] = alphas[i];
            }
        vecs = new ArrayList<Vec>(vecs.subList(0, supportVectorCount));
        alphas = Arrays.copyOfRange(alphas, 0, supportVectorCount);

        setCacheMode(null);
        setAlphas(alphas);
    }

    /**
     * Computes the raw (pre-link) score of the current model for the input.
     * 
     * @param x the input vector
     * @return the value of {@code kEvalSum(x)}
     */
    private double getPreScore(Vec x) {
        return kEvalSum(x);
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

}
